from typing import List, Dict, Optional
from pathlib import Path
import json
import random
import time
from tqdm import tqdm
from openai import OpenAI
import logging
import re

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


class Config:
    """Static settings for the StrategyQA reasoning-generation script."""

    # --- Sampling and API pacing ----------------------------------------
    SAMPLE_SIZE: int = 300      # upper bound on how many questions are sampled
    API_RATE_LIMIT: int = 2     # seconds to sleep after each question
    MAX_RETRIES: int = 3        # default number of attempts per API call

    # --- File locations -------------------------------------------------
    BASE_PATH: Path = Path("/process_COT/strategyqa")
    INPUT_FILE: Path = BASE_PATH / "reasoning_output_strategyqa_300.txt"
    OUTPUT_FILE: Path = BASE_PATH / "reasoning_output_strategyqa_300_correct.txt"
    PROGRESS_FILE: Path = BASE_PATH / "progress.json"

    # --- OpenAI client --------------------------------------------------
    API_KEY: str = ""
    BASE_URL: str = ""
    MODEL_NAME: str = "gpt-4o"

class QuestionProcessor:
    """Generate step-by-step reasoning for StrategyQA questions via the OpenAI API.

    The ids of questions whose output has been written are persisted to
    ``Config.PROGRESS_FILE`` so an interrupted run can be resumed.
    """

    def __init__(self):
        # Empty-string placeholders in Config would otherwise be handed to the
        # client verbatim (an empty key and an unusable empty base URL).
        # Passing None lets the OpenAI client fall back to the OPENAI_API_KEY /
        # OPENAI_BASE_URL environment variables and its default endpoint.
        self.client = OpenAI(
            api_key=Config.API_KEY or None,
            base_url=Config.BASE_URL or None
        )
        self.processed_questions = self._load_progress()

    def _load_progress(self) -> set:
        """Return the question ids stored in the progress file, or an empty set if it does not exist."""
        if Config.PROGRESS_FILE.exists():
            with open(Config.PROGRESS_FILE, 'r') as f:
                return set(json.load(f))
        return set()

    def _save_progress(self, question_id: str):
        """Mark ``question_id`` as processed and persist the full id set.

        The JSON is written to a temporary sibling file and then moved over the
        real progress file, so a crash mid-write cannot leave a truncated file
        that would break ``_load_progress`` on the next run.
        """
        self.processed_questions.add(question_id)
        tmp_file = Config.PROGRESS_FILE.with_suffix('.tmp')
        with open(tmp_file, 'w') as f:
            json.dump(sorted(self.processed_questions), f)
        tmp_file.replace(Config.PROGRESS_FILE)

    def get_completion(self, prompt: str, retries: Optional[int] = None) -> Optional[str]:
        """Send ``prompt`` to ``Config.MODEL_NAME`` and return the stripped reply.

        Args:
            prompt: The user message to send.
            retries: Maximum number of attempts. Defaults to ``Config.MAX_RETRIES``,
                read at call time rather than when the class is defined.

        Returns:
            The reply text, or None if every attempt failed (or ``retries`` < 1).
        """
        if retries is None:
            retries = Config.MAX_RETRIES
        for attempt in range(retries):
            try:
                response = self.client.chat.completions.create(
                    model=Config.MODEL_NAME,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant that provides step-by-step reasoning."},
                        {"role": "user", "content": prompt}
                    ]
                )
                # If the API returns no content, .strip() raises AttributeError,
                # which is handled below like any other failed attempt.
                return response.choices[0].message.content.strip()
            except Exception as e:
                logging.warning(f"API call failed (attempt {attempt + 1}/{retries}): {str(e)}")
                if attempt < retries - 1:
                    time.sleep(2 ** attempt)  # exponential backoff: 1s, 2s, 4s, ...
                else:
                    logging.error(f"API call final failure: {str(e)}")
        return None

    def parse_input_file(self, input_file: Path) -> List[Dict]:
        """Parse blank-line-separated records into question dicts.

        A record is kept only if it contains ``Question:``, ``Facts:`` and
        ``Answer:`` lines; only the rest of each matching line is captured.
        Ids are ``strategyqa_<index>``, where the index counts every non-empty
        block (including skipped ones), so ids stay stable for a given file.
        """
        with open(input_file, 'r', encoding='utf-8') as f:
            content = f.read()
        blocks = [block.strip() for block in content.split('\n\n') if block.strip()]
        qa_list = []
        for idx, block in enumerate(blocks):
            q = re.search(r"Question:\s*(.*)", block)
            fct = re.search(r"Facts:\s*(.*)", block)
            ans = re.search(r"Answer:\s*(.*)", block)
            if q and fct and ans:
                qa_list.append({
                    "id": f"strategyqa_{idx}",
                    "question": q.group(1).strip(),
                    "facts": fct.group(1).strip(),
                    "answer": ans.group(1).strip()
                })
        return qa_list

    def generate_prompt(self, qa_item: Dict) -> str:
        """Build the user prompt asking the model to justify the reference answer from the facts."""
        return (
            "You are a helpful assistant who provides step-by-step reasoning for questions. "
            "Given the question, relevant facts, and a reference answer, generate a clear, logical, and detailed reasoning process that naturally leads to the reference answer. "
            "Make sure your reasoning is accurate, uses the facts, and ends with the answer in the required format.\n\n"
            f"Question: {qa_item['question']}\n"
            f"Facts: {qa_item['facts']}\n"
            f"Reference Answer: {qa_item['answer']}\n\n"
            "Format your response exactly as:\n"
            "Let's think step by step.\n"
            "[Your step-by-step reasoning that naturally leads to the reference answer]\n"
            "The answer is [reference answer]."
        )

    def format_response(self, qa_item: Dict, response: str) -> str:
        """Rebuild a model reply into the output record format.

        The record is built as follows:
        - the ``Question:`` and ``Facts:`` headers;
        - each ``Let's think...`` line;
        - the reasoning lines that follow it;
        - each ``The answer is...`` line.

        Lines are stripped. Blank lines, and lines outside a ``Let's think`` ...
        ``The answer is`` section, are dropped. Reasoning that is not closed by
        an answer line is still kept, and a warning is logged.
        """
        lines = [f"Question: {qa_item['question']}", f"Facts: {qa_item['facts']}"]
        in_reasoning = False
        reasoning_lines = []
        saw_answer = False
        for line in response.strip().split('\n'):
            line = line.strip()
            if not line:
                continue
            if line.startswith("Let's think"):
                in_reasoning = True
                lines.append(line)
            elif line.startswith("The answer is"):
                in_reasoning = False
                saw_answer = True
                # Emit the buffered reasoning exactly once; clearing it stops a
                # repeated answer line from duplicating the reasoning.
                lines.extend(reasoning_lines)
                reasoning_lines = []
                lines.append(line)
            elif in_reasoning:
                reasoning_lines.append(line)
        if reasoning_lines:
            # No answer line closed the reasoning: keep it instead of dropping it.
            lines.extend(reasoning_lines)
        if not saw_answer:
            logging.warning(f"Response for {qa_item.get('id', '<unknown>')} has no 'The answer is' line")
        return '\n'.join(lines)

    def _select_questions(self, qa_list: List[Dict], sample_size: int) -> List[Dict]:
        """Randomly pick the not-yet-processed questions needed to reach ``sample_size``.

        Already-processed questions count toward the target, so a resumed run
        samples only the shortfall instead of a fresh, unrelated sample.
        """
        target = min(sample_size, len(qa_list))
        pending = [qa for qa in qa_list if qa['id'] not in self.processed_questions]
        already_done = len(qa_list) - len(pending)
        needed = max(0, target - already_done)
        if already_done:
            logging.info(f"{already_done} questions already processed; sampling {min(needed, len(pending))} more")
        return random.sample(pending, min(needed, len(pending)))

    def process_questions(self, qa_list: List[Dict]):
        """Generate reasoning for a random sample of ``qa_list`` and append it to ``Config.OUTPUT_FILE``.

        At most ``Config.SAMPLE_SIZE`` questions are covered in total across
        runs. A question whose API call fails is not recorded, so it stays
        eligible for selection on the next run.
        """
        Config.OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
        selected_questions = self._select_questions(qa_list, Config.SAMPLE_SIZE)
        with open(Config.OUTPUT_FILE, 'a', encoding='utf-8') as f:
            for qa in tqdm(selected_questions, desc="Processing questions"):
                prompt = self.generate_prompt(qa)
                if response := self.get_completion(prompt):
                    formatted_output = self.format_response(qa, response)
                    f.write(f"{formatted_output}\n\n")
                    # Flush before recording progress so a crash cannot mark a
                    # question as done while its output is still buffered.
                    f.flush()
                    self._save_progress(qa['id'])
                time.sleep(Config.API_RATE_LIMIT)

def main():
    """Run the pipeline: parse ``Config.INPUT_FILE`` and process a sample of it.

    Any exception is logged with its full traceback and turned into exit
    status 1, so a failed run is not reported to the shell as a success.
    """
    try:
        processor = QuestionProcessor()
        qa_list = processor.parse_input_file(Config.INPUT_FILE)
        logging.info(f"loaded {len(qa_list)} questions")
        processor.process_questions(qa_list)
        logging.info("Processing completed!")
    except Exception as e:
        # logging.exception keeps the traceback that logging.error discarded.
        logging.exception(f"Program error: {str(e)}")
        raise SystemExit(1) from e


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()